package com.gitee.fastmybatis.core.ext;

import com.gitee.fastmybatis.core.FastmybatisConfig;
import com.gitee.fastmybatis.core.FastmybatisConstants;
import com.gitee.fastmybatis.core.ext.code.client.ClassClient;
import com.gitee.fastmybatis.core.ext.exception.GenCodeException;
import com.gitee.fastmybatis.core.ext.exception.MapperFileException;
import com.gitee.fastmybatis.core.ext.info.ColumnInfo;
import com.gitee.fastmybatis.core.ext.info.TableInfo;
import com.gitee.fastmybatis.core.ext.jpa.ConditionDefinition;
import com.gitee.fastmybatis.core.ext.jpa.ConditionUtil;
import com.gitee.fastmybatis.core.ext.jpa.ConditionWrapper;
import com.gitee.fastmybatis.core.ext.jpa.JpaKeyword;
import com.gitee.fastmybatis.core.ext.spi.ClassSearch;
import com.gitee.fastmybatis.core.ext.spi.SpiContext;
import com.gitee.fastmybatis.core.query.Joint;
import com.gitee.fastmybatis.core.query.Operator;
import com.gitee.fastmybatis.core.util.IOUtil;
import com.gitee.fastmybatis.core.util.MybatisFileUtil;
import com.gitee.fastmybatis.core.util.StringUtil;
import org.apache.ibatis.annotations.*;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
import org.apache.ibatis.session.SqlSessionFactory;
import org.dom4j.Attribute;
import org.dom4j.Document;
import org.dom4j.DocumentException;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;
import org.xml.sax.SAXException;

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.regex.Pattern;

/**
 * mapper构建
 *
 * @author tanghc
 */
public class MapperLocationsBuilder {

    private static final Log LOG = LogFactory.getLog(MapperLocationsBuilder.class);

    /** Classpath location of the shared SQL mapper used when the config does not specify one. */
    private static final String DEFAULT_COMMON_SQL = "fastmybatis/commonSql.xml";

    /** Marker in the rendered mapper XML that is replaced by the generated findBy statements. */
    private static final String FIND_BY_PLACEHOLDER = "<!--_find_by_-->";

    /** Template of a generated {@code findByXxx} select statement; {@code {...}} tokens are substituted. */
    private static final String FIND_BY_TPL =
            "<select id=\"{methodName}\" resultMap=\"baseResultMap\">\n" +
                    " {bind}\n" +
                    "    SELECT\n" +
                    "        <include refid=\"baseColumns\"/>\n" +
                    "    FROM {tableName} t\n" +
                    "    <where>\n" +
                    "        {condition}\n" +
                    "        {andLogicDelete}\n" +
                    "    </where>\n" +
                    "    {order}\n" +
                    "</select>";

    /** User-supplied mapper XML resources keyed by file name (e.g. {@code UserMapper.xml}). */
    private final Map<String, MyBatisResource> mybatisMapperStore = new HashMap<>();

    private FastmybatisConfig config;

    /** Simple names of the mapper classes being built; used to skip same-named XML files. */
    private List<String> mapperNames = Collections.emptyList();

    /** Every mapper class seen by this builder; consumed by {@link #afterBuild}. */
    private Set<Class<?>> mapperClasses = new HashSet<>(64);

    /** Normalized (whitespace-free, lower-case) database dialect, e.g. {@code mysql}. */
    private String dialect;

    /**
     * Creates a builder with a default {@link FastmybatisConfig}.
     */
    public MapperLocationsBuilder() {
        this(new FastmybatisConfig());
    }

    /**
     * Creates a builder with the given configuration.
     *
     * @param config configuration, must not be null
     */
    public MapperLocationsBuilder(FastmybatisConfig config) {
        Objects.requireNonNull(config, "config can not null");
        this.config = config;
    }

    /**
     * Generates the mapper resources for the given mapper classes.
     * The internal resource store and the table info cached in {@code ExtContext}
     * are cleared afterwards, whether or not the build succeeds.
     *
     * @param mapperClasses    mapper interfaces to generate XML for
     * @param myBatisResources user-written mapper XML files to merge in
     * @param dialect          database dialect name, used to select the template; must not be null
     * @return generated mapper resources, followed by unmerged user resources and the common SQL resource
     * @throws MapperFileException if anything fails while building
     */
    public MyBatisResource[] build(Set<Class<?>> mapperClasses, List<MyBatisResource> myBatisResources, String dialect) {
        Objects.requireNonNull(dialect, "dialect can not null");
        for (MyBatisResource myBatisResource : myBatisResources) {
            // keyed by file name, e.g. XxDao.xml
            mybatisMapperStore.put(myBatisResource.getFilename(), myBatisResource);
        }
        this.dialect = normalizeDialect(dialect);
        try {
            // reject SQL written in MyBatis annotations when the config forbids it
            if (config.isDisableSqlAnnotation()) {
                checkSqlAnnotationOnMapper(mapperClasses);
            }
            return this.buildMapperLocations(mapperClasses);
        } catch (Exception e) {
            LOG.error("构建mapper失败", e);
            throw new MapperFileException(e);
        } finally {
            destroy();
        }
    }

    /**
     * Registers the session factory for every mapper class this builder has processed.
     *
     * @param sqlSessionFactory the factory built from the generated resources
     */
    public void afterBuild(SqlSessionFactory sqlSessionFactory) {
        for (Class<?> mapperClass : mapperClasses) {
            ExtContext.addSqlSessionFactory(mapperClass, sqlSessionFactory);
        }
    }

    /**
     * Scans the given package(s) for classes via the SPI {@link ClassSearch} and builds mapper resources for them.
     *
     * @param basePackage      one or more packages separated by {@code StringUtil.CONFIG_LOCATION_DELIMITERS}
     * @param myBatisResources user-written mapper XML files to merge in
     * @param dialect          database dialect name
     * @return generated mapper resources
     * @throws RuntimeException wrapping any checked exception raised by the class search;
     *                          unchecked exceptions (e.g. {@link MapperFileException}) propagate unchanged
     */
    public MyBatisResource[] build(String basePackage, List<MyBatisResource> myBatisResources, String dialect) {
        String[] basePackages = StringUtil.tokenizeToStringArray(basePackage,
                StringUtil.CONFIG_LOCATION_DELIMITERS);
        ClassSearch classSearch = SpiContext.getClassSearch();
        Set<Class<?>> clazzsSet;
        try {
            clazzsSet = classSearch.search(Object.class, basePackages);
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return build(clazzsSet, myBatisResources, dialect);
    }

    /**
     * Forbids SQL written in MyBatis annotations (and the *Provider annotations) on mapper methods.
     *
     * @param clazzsSet mapper classes to inspect; null or empty means nothing to check
     * @throws IllegalStateException on the first offending method
     */
    private void checkSqlAnnotationOnMapper(Set<Class<?>> clazzsSet) {
        if (clazzsSet == null || clazzsSet.isEmpty()) {
            return;
        }
        for (Class<?> mapperClass : clazzsSet) {
            for (Method m : mapperClass.getMethods()) {
                checkInsert(mapperClass, m);
                checkDelete(mapperClass, m);
                checkSelect(mapperClass, m);
                checkUpdate(mapperClass, m);
            }
        }
    }

    /** Rejects {@code @Insert} / {@code @InsertProvider} on the method. */
    private void checkInsert(Class<?> mapperClass, Method m) {
        Insert anno = m.getAnnotation(Insert.class);
        if (null != anno) {
            throw new IllegalStateException("本项目禁止将sql写在Mybatis注解中.问题Mapper:" + mapperClass.getName() + ",问题方法：" + m.getName());
        }
        InsertProvider annoSp = m.getAnnotation(InsertProvider.class);
        if (null != annoSp) {
            throw new IllegalStateException("本项目禁止使用InsertProvider注解.问题Mapper:" + mapperClass.getName() + ",问题方法：" + m.getName());
        }
    }

    /** Rejects {@code @Select} / {@code @SelectProvider} on the method. */
    private void checkSelect(Class<?> mapperClass, Method m) {
        Select anno = m.getAnnotation(Select.class);
        if (null != anno) {
            throw new IllegalStateException("本项目禁止将sql写在Mybatis注解中.问题Mapper:" + mapperClass.getName() + ",问题方法：" + m.getName());
        }
        SelectProvider annoSp = m.getAnnotation(SelectProvider.class);
        if (null != annoSp) {
            throw new IllegalStateException("本项目禁止使用SelectProvider注解.问题Mapper:" + mapperClass.getName() + ",问题方法：" + m.getName());
        }
    }

    /** Rejects {@code @Update} / {@code @UpdateProvider} on the method. */
    private void checkUpdate(Class<?> mapperClass, Method m) {
        Update anno = m.getAnnotation(Update.class);
        if (null != anno) {
            throw new IllegalStateException("本项目禁止将sql写在Mybatis注解中.问题Mapper:" + mapperClass.getName() + ",问题方法：" + m.getName());
        }

        UpdateProvider annoSp = m.getAnnotation(UpdateProvider.class);
        if (null != annoSp) {
            throw new IllegalStateException("本项目禁止使用UpdateProvider注解.问题Mapper:" + mapperClass.getName() + ",问题方法：" + m.getName());
        }
    }

    /** Rejects {@code @Delete} / {@code @DeleteProvider} on the method. */
    private void checkDelete(Class<?> mapperClass, Method m) {
        Delete anno = m.getAnnotation(Delete.class);
        if (null != anno) {
            throw new IllegalStateException("本项目禁止将sql写在Mybatis注解中.问题Mapper:" + mapperClass.getName() + ",问题方法：" + m.getName());
        }

        DeleteProvider annoSp = m.getAnnotation(DeleteProvider.class);
        if (null != annoSp) {
            throw new IllegalStateException("本项目禁止使用DeleteProvider注解.问题Mapper:" + mapperClass.getName() + ",问题方法：" + m.getName());
        }
    }

    /** Releases per-build state: the XML resource store and the table info cached in {@code ExtContext}. */
    private void destroy() {
        mybatisMapperStore.clear();
        ExtContext.clearTableInfo();
    }

    /** Returns the user XML resource with the given file name, or null if none was supplied. */
    private MyBatisResource getMapperFile(String mapperFileName) {
        return mybatisMapperStore.get(mapperFileName);
    }

    /**
     * Generates all mapper resources: one per mapper class, then every user resource that was
     * not merged into a generated one, then the common SQL resource.
     */
    private MyBatisResource[] buildMapperLocations(Set<Class<?>> clazzsSet) {
        ExtContext.addMapperClass(clazzsSet);
        mapperClasses.addAll(clazzsSet);
        mapperNames = this.buildMapperNames(clazzsSet);

        List<MyBatisResource> mapperLocations = this.buildMapperResource(clazzsSet);

        this.addUnmergedResource(mapperLocations);

        this.addCommonSqlClasspathMapper(mapperLocations);

        return mapperLocations.toArray(new MyBatisResource[0]);
    }

    /**
     * Renders the dialect template for every mapper class, merges user XML and generated
     * findBy statements into it, and optionally saves the result to disk.
     *
     * @throws GenCodeException wrapping any failure during generation
     */
    private List<MyBatisResource> buildMapperResource(Set<Class<?>> clazzsSet) {
        int classCount = clazzsSet.size();
        if (classCount == 0) {
            return new ArrayList<>();
        }
        final MyBatisResource templateResource = this.buildTemplateResource(this.getDbName());
        LOG.debug("使用模板:" + templateResource);
        final String globalVmLocation = this.config.getGlobalVmLocation();
        final ClassClient codeClient = new ClassClient(config);
        final List<MyBatisResource> mapperLocations = new ArrayList<>(classCount);

        long startTime = System.currentTimeMillis();
        try {
            String templateContent = templateResource.getContent();
            for (Class<?> daoClass : clazzsSet) {
                String xml = codeClient.genMybatisXml(daoClass, templateContent, globalVmLocation);
                xml = mergeExtMapperFile(daoClass, xml);
                saveMapper(daoClass.getSimpleName() + FastmybatisConstants.XML_SUFFIX, xml);
                mapperLocations.add(MyBatisResource.build(xml, daoClass));
            }

            long endTime = System.currentTimeMillis();
            LOG.debug("生成Mapper内容总耗时：" + (endTime - startTime) / 1000.0 + "秒");
            return mapperLocations;
        } catch (Exception e) {
            LOG.error(e.getMessage(), e);
            throw new GenCodeException(e);
        }
    }

    /** Collects the simple names of the mapper classes, e.g. {@code UserMapper}. */
    private List<String> buildMapperNames(Set<Class<?>> clazzsSet) {
        List<String> list = new ArrayList<>(clazzsSet.size());
        for (Class<?> mapperClass : clazzsSet) {
            list.add(mapperClass.getSimpleName());
        }
        return list;
    }

    /**
     * Saves the generated mapper XML (UTF-8) into the configured save directory.
     * Does nothing when no save directory is configured.
     *
     * @param filename file name inside the save directory
     * @param content  XML content
     * @throws IOException if the file cannot be written
     */
    private void saveMapper(String filename, final String content) throws IOException {
        String saveDir = config.getMapperSaveDir();
        if (StringUtil.hasText(saveDir)) {
            String path = saveDir + "/" + filename;
            LOG.debug("保存mapper文件到" + path);
            try (OutputStream out = new FileOutputStream(path)) {
                IOUtil.copy(IOUtil.toInputStream(content, StandardCharsets.UTF_8), out);
            }
        }
    }

    /**
     * Locates the Velocity template for the dialect: a {@code <dialect>.vm} at the classpath root wins;
     * otherwise it is resolved under the configured (or default) template classpath,
     * e.g. {@code /fastmybatis/tpl/mysql.vm}.
     */
    private MyBatisResource buildTemplateResource(String dialect) {
        // e.g. mysql.vm
        String templateFileName = this.buildTemplateFileName(dialect);
        // a template at the classpath root takes precedence
        MyBatisResource myBatisResource = MyBatisResource.buildFromClasspath(templateFileName);
        if (myBatisResource.exists()) {
            return myBatisResource;
        }
        String templateClasspath = config.getTemplateClasspath();
        if (StringUtil.isEmpty(templateClasspath)) {
            templateClasspath = FastmybatisConstants.DEFAULT_CLASS_PATH;
        }
        // classpath prefix + dialect + suffix, e.g. /fastmybatis/tpl/mysql.vm
        String location = templateClasspath + templateFileName;
        return MyBatisResource.buildFromClasspath(location);
    }

    /**
     * Builds the template file name for a dialect, e.g. {@code mysql.vm}.
     */
    private String buildTemplateFileName(String dialect) {
        return normalizeDialect(dialect) + FastmybatisConstants.TEMPLATE_SUFFIX;
    }

    /**
     * Removes all whitespace and lower-cases the dialect name.
     * {@link Locale#ROOT} keeps names like "SQLITE" stable under locales such as Turkish.
     */
    private static String normalizeDialect(String dialect) {
        return dialect.replaceAll("\\s", "").toLowerCase(Locale.ROOT);
    }

    /** Appends the shared SQL mapper (configured classpath or {@link #DEFAULT_COMMON_SQL}). */
    private void addCommonSqlClasspathMapper(List<MyBatisResource> mapperLocations) {
        String commonSqlClasspath = config.getCommonSqlClasspath();
        if (StringUtil.isEmpty(commonSqlClasspath)) {
            commonSqlClasspath = DEFAULT_COMMON_SQL;
        }
        MyBatisResource myBatisResource = MyBatisResource.buildFromClasspath(commonSqlClasspath);
        mapperLocations.add(myBatisResource);
    }

    /**
     * Appends every user XML resource that was not merged into a generated mapper.
     */
    private void addUnmergedResource(List<MyBatisResource> mapperLocations) {
        for (MyBatisResource mapperResourceDefinition : this.mybatisMapperStore.values()) {
            if (mapperResourceDefinition.isMerged()) {
                continue;
            }
            LOG.debug("加载未合并Mapper：" + mapperResourceDefinition.getFilename());
            mapperLocations.add(mapperResourceDefinition);
        }
    }

    /**
     * Merges user-written XML into the generated mapper XML.
     * The same-named file (UserMapper.java -> UserMapper.xml) is merged first, then any other file with
     * the same namespace. Generated findBy statements are added last.
     *
     * @throws IOException       if a resource cannot be read
     * @throws DocumentException if a resource is not valid XML
     */
    private String mergeExtMapperFile(Class<?> mapperClass, String xml) throws IOException, DocumentException {
        String mapperFileName = mapperClass.getSimpleName() + FastmybatisConstants.XML_SUFFIX;
        MyBatisResource myBatisResource = this.getMapperFile(mapperFileName);
        StringBuilder extXml = new StringBuilder();

        if (myBatisResource != null) {
            try (InputStream in = myBatisResource.getInputStream()) {
                extXml.append(MybatisFileUtil.getExtFileContent(in));
            }
            myBatisResource.setMerged(true);
        }
        // then every other XML declaring the same namespace
        extXml.append(this.buildOtherMapperContent(mapperClass, this.mybatisMapperStore.values()));

        xml = xml.replace(FastmybatisConstants.EXT_MAPPER_PLACEHOLDER, extXml.toString());

        return mergeFindBySql(mapperClass, xml);
    }

    /**
     * Generates a select statement for each abstract {@code findByXxx} mapper method that has no
     * annotations and no statement with the same id already in {@code xml}. The statements are inserted
     * at {@link #FIND_BY_PLACEHOLDER}.
     */
    private String mergeFindBySql(Class<?> mapperClass, String xml) {
        TableInfo tableInfo = ExtContext.getTableInfo(mapperClass);
        StringBuilder findBySql = new StringBuilder();
        // overloaded methods share a name; MyBatis rejects duplicate statement ids
        Set<String> generatedIds = new HashSet<>();
        for (Method method : mapperClass.getMethods()) {
            String methodName = method.getName();
            if (!isFindByCandidate(method)
                    || generatedIds.contains(methodName)
                    || hasStatementId(xml, methodName)) {
                continue;
            }
            ConditionDefinition conditionDefinition = ConditionUtil.getConditions(methodName);
            List<ConditionWrapper> conditions = conditionDefinition.getConditionWrappers();
            checkParam(conditions, method);

            TableSqlBlock tableSqlBlock = buildConditionBlock(tableInfo, conditions);

            String sql = FIND_BY_TPL.replace("{methodName}", methodName)
                    .replace("{bind}", tableSqlBlock.bindBlock)
                    .replace("{tableName}", tableInfo.getTableName())
                    .replace("{condition}", tableSqlBlock.conditionBlack)
                    .replace("{andLogicDelete}", buildLogicDeleteCondition(tableInfo))
                    .replace("{order}", Objects.toString(conditionDefinition.getOrderBy(), ""));
            LOG.debug("generate findBy sql, mapper:" + mapperClass.getName() + ", sql:\n" + sql);
            findBySql.append(sql).append("\n");
            generatedIds.add(methodName);
        }
        return xml.replace(FIND_BY_PLACEHOLDER, findBySql.toString());
    }

    /** True for abstract, non-static, un-annotated methods whose name starts with the findBy prefix. */
    private static boolean isFindByCandidate(Method method) {
        return !method.isDefault()
                && !Modifier.isStatic(method.getModifiers())
                && method.getAnnotations().length == 0
                && method.getName().startsWith(ConditionUtil.FIND_BY_PREFIX);
    }

    /**
     * Tells whether {@code xml} already declares an element with {@code id="statementId"} (either quote style).
     * An exact attribute match is used, so {@code findByNameAndAge} does not hide {@code findByName}.
     */
    private static boolean hasStatementId(String xml, String statementId) {
        Pattern idAttribute = Pattern.compile("\\bid\\s*=\\s*([\"'])" + Pattern.quote(statementId) + "\\1");
        return idAttribute.matcher(xml).find();
    }

    /**
     * Ensures a method whose conditions need arguments actually declares parameters.
     *
     * @throws IllegalArgumentException if the method has no parameters but the conditions need some
     */
    private void checkParam(List<ConditionWrapper> conditions, Method method) {
        long paramCount = conditions.stream()
                .filter(conditionWrapper -> !JpaKeyword.isNoParamKeyword(conditionWrapper.getKeyword()))
                .count();
        if (method.getParameters().length == 0 && paramCount > 0) {
            throw new IllegalArgumentException("Mapper方法 '" + method + "' 必须要有对应的参数，如：findByName(String name)");
        }
    }

    /**
     * Translates parsed findBy conditions into MyBatis XML: {@code <bind>} elements for LIKE patterns,
     * e.g. {@code <bind name="pattern_0" value="'%' + arg0 + '%'" />}, plus the WHERE condition text.
     * Arguments are referenced positionally as {@code arg0, arg1, ...}. Only keywords that consume
     * an argument, according to {@link JpaKeyword#isNoParamKeyword} (the same rule {@link #checkParam} uses), advance that index.
     */
    private TableSqlBlock buildConditionBlock(TableInfo tableInfo, List<ConditionWrapper> conditions) {
        List<String> bindBlocks = new ArrayList<>();
        List<String> sql = new ArrayList<>();
        int paramIndex = 0;
        for (ConditionWrapper wrapper : conditions) {
            JpaKeyword keyword = wrapper.getKeyword();
            String column = wrapper.getColumn();
            String operate = keyword.getOperate();
            // e.g. findByNameIsNullAndAge(age) must bind age to arg0, not arg1
            int idx = paramIndex;
            if (!JpaKeyword.isNoParamKeyword(keyword)) {
                paramIndex++;
            }
            if (keyword == JpaKeyword.between) {
                // BETWEEN consumes a second argument
                int idx2 = paramIndex++;
                operate = operate.replaceFirst("\\?", "#{arg" + idx + "}");
                operate = operate.replaceFirst("\\?", "#{arg" + idx2 + "}");
            } else if (keyword == JpaKeyword.like || keyword == JpaKeyword.not_like || keyword == JpaKeyword.containing) {
                bindBlocks.add("<bind name=\"pattern_" + idx + "\" value=\"'%' + arg" + idx + " + '%'\" />");
                operate = operate.replace("{idx}", String.valueOf(idx));
            } else if (keyword == JpaKeyword.starting_with) {
                bindBlocks.add("<bind name=\"pattern_" + idx + "\" value=\"arg" + idx + " + '%'\" />");
                operate = operate.replace("{idx}", String.valueOf(idx));
            } else if (keyword == JpaKeyword.ending_with) {
                bindBlocks.add("<bind name=\"pattern_" + idx + "\" value=\"'%' + arg" + idx + "\" />");
                operate = operate.replace("{idx}", String.valueOf(idx));
            } else if (keyword == JpaKeyword.in || keyword == JpaKeyword.not_in) {
                String foreach = "\n<foreach collection=\"arg" + idx + "\" item=\"value\" open=\"(\" separator=\",\" close=\")\">#{value}</foreach>\n";
                operate = operate.replace("?", foreach);
            } else if (keyword == JpaKeyword.ignore_case) {
                column = upperColumn(column);
                operate = operate.replace("?", upperColumn("#{arg" + idx + "}"));
            } else if (keyword == JpaKeyword.True) {
                operate = operate.replace("?", "${" + booleanLiteral(tableInfo, column, true) + "}");
            } else if (keyword == JpaKeyword.False) {
                operate = operate.replace("?", "${" + booleanLiteral(tableInfo, column, false) + "}");
            } else {
                operate = operate.replace("?", "#{arg" + idx + "}");
            }

            sql.add(wrapper.getJoint());
            sql.add(column);
            sql.add(operate);
        }
        return new TableSqlBlock(String.join("\n", bindBlocks), String.join(" ", sql), "");
    }

    /**
     * Returns the SQL literal for a boolean condition: {@code true/false} when the column type is
     * "boolean" (case-insensitive), otherwise {@code 1/0}.
     *
     * @throws IllegalArgumentException if the column is not part of the table
     */
    private static String booleanLiteral(TableInfo tableInfo, String column, boolean value) {
        ColumnInfo columnInfo = tableInfo.getColumnInfoByColumnName(column);
        if (columnInfo == null) {
            throw new IllegalArgumentException("column '" + column + "' not found in table " + tableInfo.getTableName());
        }
        if ("boolean".equalsIgnoreCase(columnInfo.getType())) {
            return value ? "true" : "false";
        }
        return value ? "1" : "0";
    }

    /** Wraps an SQL expression in {@code UPPER(...)}. */
    private String upperColumn(String column) {
        return "UPPER(" + column + ")";
    }

    /**
     * Builds {@code AND <logicDeleteColumn> = ${<notDeletedValue>}}, or "" if the table has no logic-delete column.
     */
    private String buildLogicDeleteCondition(TableInfo tableInfo) {
        if (!tableInfo.isHasLogicDeleteColumn()) {
            return "";
        }
        List<String> expressions = Arrays.asList(
                Joint.AND.getJoint(),
                tableInfo.getLogicDeleteColumnName(),
                Operator.eq.getOperator(),
                "${" + tableInfo.getLogicNotDeleteValue() + "}"
        );
        return String.join(" ", expressions);
    }

    /**
     * Collects the body of every not-yet-merged XML resource whose namespace equals the mapper's class name,
     * and marks those resources as merged.
     * One Mapper.java may thus be backed by several Mapper.xml files. Each developer can maintain
     * their own file, which avoids commit conflicts and follows the open/closed principle.
     * Files named after any mapper class are skipped here; they are merged by name instead.
     *
     * @throws IOException       if a resource cannot be read
     * @throws DocumentException if a resource is not valid XML
     * @throws MapperFileException if a candidate resource has no namespace
     */
    private String buildOtherMapperContent(Class<?> mapperClass, Collection<MyBatisResource> mapperResourceDefinitions) throws IOException, DocumentException {
        StringBuilder xml = new StringBuilder();
        String trueNamespace = mapperClass.getName();
        SAXReader reader = null;
        for (MyBatisResource mapperResourceDefinition : mapperResourceDefinitions) {
            String baseName = stripExtension(mapperResourceDefinition.getFilename());
            if (mapperResourceDefinition.isMerged() || mapperNames.contains(baseName)) {
                continue;
            }
            if (reader == null) {
                reader = this.buildSAXReader();
            }
            Document document;
            try (InputStream in = mapperResourceDefinition.getInputStream()) {
                document = reader.read(in);
            }
            Element mapperNode = document.getRootElement();

            Attribute attrNamespace = mapperNode.attribute(FastmybatisConstants.ATTR_NAMESPACE);
            String namespaceValue = attrNamespace == null ? null : attrNamespace.getValue();

            if (StringUtil.isEmpty(namespaceValue)) {
                throw new MapperFileException("Mapper文件[" + mapperResourceDefinition.getFilename() + "]的namespace不能为空。");
            }

            if (trueNamespace.equals(namespaceValue)) {
                xml.append(MybatisFileUtil.trimMapperNode(mapperNode));
                mapperResourceDefinition.setMerged(true);
            }
        }
        return xml.toString();
    }

    /** Returns the file name without its last extension, e.g. {@code UserMapper.xml -> UserMapper}. */
    private static String stripExtension(String filename) {
        int dot = filename.lastIndexOf('.');
        return dot > 0 ? filename.substring(0, dot) : filename;
    }

    /** Creates a SAX reader using the project encoding; a failure to set the configured feature is logged, not fatal. */
    private SAXReader buildSAXReader() {
        SAXReader reader = new SAXReader();
        reader.setEncoding(FastmybatisConstants.ENCODE);
        try {
            reader.setFeature(FastmybatisConstants.SAXREADER_FEATURE, false);
        } catch (SAXException e) {
            LOG.error("reader.setFeature fail by ", e);
        }
        return reader;
    }

    public void setConfig(FastmybatisConfig config) {
        this.config = config;
    }

    /** Returns the normalized dialect name, e.g. {@code mysql}. */
    public String getDbName() {
        return dialect;
    }

    /** Sets the dialect; whitespace is removed and the name lower-cased. */
    public void setDbName(String dialect) {
        this.dialect = normalizeDialect(dialect);
    }

    public void setMapperExecutorPoolSize(int poolSize) {
        config.setMapperExecutorPoolSize(poolSize);
    }

    public FastmybatisConfig getConfig() {
        return config;
    }

    /**
     * SQL fragments generated for one findBy method: the {@code <bind>} elements and the WHERE condition.
     * {@code orderbyBlock} is currently always empty.
     */
    static class TableSqlBlock {
        private final String bindBlock;
        private final String conditionBlack;

        private final String orderbyBlock;

        public TableSqlBlock(String bindBlock, String conditionBlack, String orderbyBlock) {
            this.bindBlock = bindBlock;
            this.conditionBlack = conditionBlack;
            this.orderbyBlock = orderbyBlock;
        }
    }

}
